from torch.utils import data
import torch


class Horse_Colic(data.Dataset):
    """Map-style dataset read from a whitespace-separated numeric text file.

    Each non-blank line of the file is parsed as a row of floats. All
    columns except the last form the feature vector, and the last column
    is the label.
    """

    def __init__(self, root_path, train=True):
        """Load every row of ``train.txt`` or ``test.txt`` into memory.

        Args:
            root_path: String prefix that the file name is appended to
                directly (``root_path + "train.txt"``). It must therefore
                end with a path separator when it names a directory.
            train: When true, read ``train.txt``; otherwise read ``test.txt``.

        Raises:
            OSError: If the selected file cannot be opened.
            ValueError: If a non-blank line contains a non-numeric token.
        """
        file_name = "train.txt" if train else "test.txt"

        # Read the file's rows; the context manager guarantees the handle
        # is closed even if parsing fails.
        with open(root_path + file_name, encoding="utf-8") as file:
            # Skip whitespace-only lines (e.g. a trailing empty line), which
            # would otherwise parse to [] and break the label extraction.
            rows = [
                [float(token) for token in line.split()]
                for line in file
                if line.strip()
            ]

        self.data = [row[:-1] for row in rows]
        self.label = [row[-1] for row in rows]
        self.len = len(self.label)

    def __getitem__(self, item):
        """Return ``(features, label)`` for row ``item`` as tensors."""
        return torch.tensor(self.data[item]), torch.tensor(self.label[item])

    def __len__(self):
        """Return the number of rows loaded from the file."""
        return self.len

